{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "多输入通道和多少输出通道\n",
    "channel"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "多输入通道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import nd\n",
    "\n",
    "def corr2d_multi_in(X, K):\n",
    "    return nd.add_n(*[d2l.corr2d(x, k) for x, k in zip(X, K)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[ 56.  72.]\n",
       " [104. 120.]]\n",
       "<NDArray 2x2 @cpu(0)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = nd.array([[[0, 1, 2],\n",
    "               [3, 4, 5],\n",
    "               [6, 7, 8]],\n",
    "             [[1, 2 ,3],\n",
    "             [4, 5, 6],\n",
    "             [7, 8, 9]]])\n",
    "K = nd.array([[[0,1],\n",
    "              [2, 3]],\n",
    "            [[1, 2],\n",
    "            [3, 4]]])\n",
    "corr2d_multi_in(X, K)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "多输出通道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def corr2d_multi_in_out(X, K):\n",
    "    return nd.stack(*[corr2d_multi_in(X, k) for k in K])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 2, 2, 2)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "K = nd.stack(K, K + 1, K + 2)\n",
    "K.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[[ 56.  72.]\n",
       "  [104. 120.]]\n",
       "\n",
       " [[ 76. 100.]\n",
       "  [148. 172.]]\n",
       "\n",
       " [[ 96. 128.]\n",
       "  [192. 224.]]]\n",
       "<NDArray 3x2x2 @cpu(0)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "corr2d_multi_in_out(X, K)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "1*1卷积层\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def corr2d_multi_in_out_1x1(X, K):\n",
    "    c_i, h, w = X.shape\n",
    "    c_o = K.shape[0]\n",
    "    X = X.reshape(c_i, h*w)\n",
    "    K = K.reshape(c_o, c_i)\n",
    "    Y = nd.dot(K, X) #全连接的矩阵乘法\n",
    "    return Y.reshape((c_o, h, w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = nd.random.uniform(shape=(3, 3, 3))\n",
    "K = nd.random.uniform(shape=(2, 3, 1, 1))\n",
    "\n",
    "Y1 = corr2d_multi_in_out_1x1(X, K)\n",
    "Y2 = corr2d_multi_in_out(X, K)\n",
    "\n",
    "(Y1 - Y2).norm().asscalar() < 1e-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[[0.63251853 0.70270896 0.8923846 ]\n",
       "  [1.0109798  0.70354044 0.9426006 ]\n",
       "  [0.65125    0.909348   0.68587   ]]\n",
       "\n",
       " [[0.45974678 0.47996217 0.3274668 ]\n",
       "  [0.37361276 0.25606534 0.66247874]\n",
       "  [0.09572    0.7063811  0.23235497]]]\n",
       "<NDArray 2x3x3 @cpu(0)>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[[0.63251853 0.70270896 0.8923846 ]\n",
       "  [1.0109798  0.70354044 0.9426006 ]\n",
       "  [0.65125    0.909348   0.68587   ]]\n",
       "\n",
       " [[0.45974678 0.47996217 0.3274668 ]\n",
       "  [0.37361276 0.2560653  0.66247874]\n",
       "  [0.09572    0.7063811  0.23235497]]]\n",
       "<NDArray 2x3x3 @cpu(0)>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[[1. 1. 1.]\n",
       "  [1. 1. 1.]\n",
       "  [1. 1. 1.]]\n",
       "\n",
       " [[1. 1. 1.]\n",
       "  [1. 0. 1.]\n",
       "  [1. 1. 1.]]]\n",
       "<NDArray 2x3x3 @cpu(0)>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y1 == Y2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
